# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import copy
import logging
from collections import defaultdict
from typing import Callable, List, Optional

import numpy as np
import paddle

from paddlemix.models.internvl2.constants import (
    IMG_CONTEXT_TOKEN,
    IMG_END_TOKEN,
    IMG_START_TOKEN,
)

from .trainer_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def is_dist_avail_and_initialized():
    return paddle.distributed.is_initialized()


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return paddle.distributed.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return paddle.distributed.get_rank()


class PackedDataset(paddle.io.IterableDataset):
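    """Iterable dataset that packs samples from multiple sub-datasets into
    dense sequences for efficient multimodal training.

    Incoming samples are merged into buffers until a buffer reaches
    ``num_images_expected`` images or ``max_packed_tokens`` tokens; buffers
    that overflow the token budget are split on image boundaries. Each
    dataloader worker samples from ``datasets`` according to
    ``dataset_weight`` and keeps resumable per-dataset state in
    ``_state_dict``.
    """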
    def __init__(
        self,
        tokenizer,
        data_rank,
        data_world_size,
        datasets: List,
        dataset_weight: Optional[List[int]] = None,
        num_images_expected: int = 6,
        max_packed_tokens: int = 32768,
        max_buffer_size: int = 100,
        log_freq: int = 1000000,
        strict_mode: bool = False,
        debug_mode: bool = False,
        replacement: bool = True,
        allow_overflow: bool = True,
        allow_empty_data: bool = False,
        allow_deduplicated_ds_name: bool = False,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_rank = data_rank
        self.data_world_size = data_world_size
        self.datasets = datasets
        self.num_images_expected = num_images_expected
        self.max_buffer_size = max_buffer_size
        self.log_freq = log_freq
        self.strict_mode = strict_mode
        self.debug_mode = debug_mode
        self.replacement = replacement
        self.allow_overflow = allow_overflow
        self.allow_empty_data = allow_empty_data
        self.max_packed_tokens = max_packed_tokens
        self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
        assert self.img_start_token_id != self.tokenizer.unk_token_id
        assert self.img_token_id != self.tokenizer.unk_token_id
        assert self.img_end_token_id != self.tokenizer.unk_token_id
        if dataset_weight is None:
            dataset_weight = [1] * len(datasets)
        self.dataset_type = [d.dataset_type for d in self.datasets]
        self.datasets_orig = datasets
        self.dataset_weight_orig = [(w / sum(dataset_weight)) for w in dataset_weight]
        self.datasets = [ds for ds in self.datasets_orig]
        self.dataset_weight = [w for w in self.dataset_weight_orig]
        self.worker_id = None
        self.worker_state_key = None
        self.dataset_iter_list = None
        self._state_dict = {"sample_info": {d.ds_name: 0 for d in self.datasets}}
        self.worker_custom_infos = None
        ds_name_list = [d.ds_name for d in self.datasets]
        if not allow_deduplicated_ds_name:
            assert len(ds_name_list) == len(set(ds_name_list)), f"duplicated ds_name found: {ds_name_list}"
        for ds in self.datasets:
            if ds.max_num_images > self.num_images_expected:
                logger.warning(
                    f"ds.max_num_images={ds.max_num_images!r} of {ds.ds_name} is larger than self.num_images_expected={self.num_images_expected!r}"
                )
                ds.max_num_images = self.num_images_expected
            if ds.max_tokens > self.max_packed_tokens:
                logger.warning(
                    f"ds.max_tokens={ds.max_tokens!r} of {ds.ds_name} is larger than self.max_packed_tokens={self.max_packed_tokens!r}"
                )
                ds.max_tokens = self.max_packed_tokens
            self._state_dict[ds.ds_name] = {}
        if get_rank() == 0:
            logger.info(
                f"Loaded dataset to pack: {ds_name_list}, self.num_images_expected={self.num_images_expected!r}, self.max_packed_tokens={self.max_packed_tokens!r}, self.replacement={self.replacement!r}, self.allow_overflow={self.allow_overflow!r}"
            )
            temp = []
            for ds, ds_w in zip(self.datasets, self.dataset_weight):
                temp.append(f"{ds.ds_name:<25}: {ds_w * 100:.2f}%")
            temp = "\n".join(temp)
            logger.info(f"Sampling prob for each dataset:\n{temp}")
        if self.allow_empty_data:
            logger.warning("allow_empty_data is enabled, note that empty data may be generated!")

    def load_state_dict(self, state_dict, custom_infos=None):
        self.worker_custom_infos = custom_infos
        self._state_dict.update(state_dict)
        for ds in self.datasets:
            if ds.ds_name in self._state_dict:
                ds.set_state_dict(state_dict=self._state_dict[ds.ds_name])
                logger.info(f"ds.ds_name={ds.ds_name!r} is resumed.")
            else:
                logger.warning(f"ds.ds_name={ds.ds_name!r} is not resumed.")

    def _should_log(self):
        worker_info = paddle.io.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        # only the first dataloader worker of rank 0 logs
        return num_workers * get_rank() + worker_id == 0

    def next_data(self, current_dataset_idx):
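        """Fetch the next sample from the dataset at ``current_dataset_idx``.

        An exhausted dataset is restarted when ``self.replacement`` is True
        and dropped otherwise; StopIteration is raised once every dataset is
        gone. The fetched sample is tagged with per-dataset ``type_ids`` and
        the resumable counters in ``_state_dict`` are updated.
        """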
        while True:
            try:
                current_sample = next(self.dataset_iter_list[current_dataset_idx])
                break
            except StopIteration:
                if self.replacement:
                    try:
                        self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
                        current_sample = next(self.dataset_iter_list[current_dataset_idx])
                        break
                    except Exception:
                        self.datasets.pop(current_dataset_idx)
                        self.dataset_iter_list.pop(current_dataset_idx)
                        self.dataset_weight.pop(current_dataset_idx)
                        if len(self.datasets) == 0:
                            raise StopIteration
                        current_dataset_idx = np.random.choice(len(self.datasets))
                else:
                    self.datasets.pop(current_dataset_idx)
                    self.dataset_iter_list.pop(current_dataset_idx)
                    self.dataset_weight.pop(current_dataset_idx)
                    if len(self.datasets) == 0:
                        raise StopIteration
                    current_dataset_idx = np.random.choice(len(self.datasets))
            except Exception:
                logger.exception("Unexpected error!")
                if len(self.datasets) == 0:
                    raise StopIteration
                current_dataset_idx = np.random.choice(len(self.datasets))
        current_ds_name = self.datasets[current_dataset_idx].ds_name
        current_sample["type_ids"] = paddle.zeros_like(x=current_sample["input_ids"]) + current_dataset_idx
        if self.worker_state_key not in self._state_dict[current_ds_name]:
            self._state_dict[current_ds_name][self.worker_state_key] = {}
        meta_info = current_sample.pop("meta_info", {})
        self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
        self._state_dict["sample_info"][self.datasets[current_dataset_idx].ds_name] += 1
        return current_sample

    def find_buffer(self, buffer_list, new_sample):
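        """Pop and return the first buffer in ``buffer_list`` that can absorb
        ``new_sample`` without exceeding the image budget and the token
        budget, or None if no buffer fits. With ``allow_overflow``, once the
        buffer list is at least half of ``max_buffer_size``, a buffer that
        only satisfies the image budget may be chosen even though the merged
        token count overflows ``max_packed_tokens``.
        """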
        find = False
        find_idx = -1
        num_images_current = new_sample["pixel_values"].shape[0]
        for buffer_idx, buffer in enumerate(buffer_list):
            num_images_buffer = buffer["pixel_values"].shape[0]
            if num_images_buffer + num_images_current <= self.num_images_expected:
                num_merged_tokens = new_sample["input_ids"].shape[0] + buffer["input_ids"].shape[0]
                if num_merged_tokens <= self.max_packed_tokens:
                    find = True
                    find_idx = buffer_idx
                    break
                if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
                    find = True
                    find_idx = buffer_idx
        if find:
            return buffer_list.pop(find_idx)
        return None

    def update_buffer(self, buffer, new_sample):
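        """Concatenate ``new_sample`` onto ``buffer`` field by field, tagging
        the new tokens with the next ``data_index``; if ``buffer`` is None,
        ``new_sample`` starts a fresh buffer with ``data_index`` 0.
        """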
        if buffer is None:
            new_sample["data_index"] = paddle.zeros_like(x=new_sample["input_ids"])
            return new_sample
        new_sample["data_index"] = paddle.ones_like(x=new_sample["input_ids"]) + buffer["data_index"][-1].item()
        assert buffer.keys() == new_sample.keys()
        for k in buffer:
            buffer[k] = paddle.concat(x=[buffer[k], new_sample[k]])
        return buffer

    @staticmethod
    def check_valid(sample_to_check, min_active_tokens_ratio=1 / 256):
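        """Return True when the fraction of tokens with non-ignored labels
        exceeds ``min_active_tokens_ratio``, filtering out packed samples that
        are almost entirely masked out of the loss.
        """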
        num_ignore_tokens = (sample_to_check["labels"] == IGNORE_TOKEN_ID).sum()
        num_tokens = sample_to_check["labels"].size
        return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio

    @staticmethod
    def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
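        """Split ``buffer`` into chunks of at most ``max_tokens`` tokens each.

        The cut point is shifted to the nearest image boundary so that no
        <img> start/context/end span is divided across chunks; chunks without
        any image or with too few active label tokens are dropped.
        """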
        if buffer["input_ids"].shape[0] <= max_tokens:
            return [buffer]

        def _image_is_splitted(input_ids, cut_idx):
            is_image_start = input_ids[cut_idx].item() == img_start_token_id
            is_image_token = input_ids[cut_idx].item() == img_token_id
            is_image_end = input_ids[cut_idx].item() == img_end_token_id
            return is_image_start or is_image_token or is_image_end

        def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
            assert (right_idx is None) == (right_img_idx is None)
            left_sample = {}
            right_sample = {} if right_idx is not None else None
            for k in sample_to_split:
                if k in ["input_ids", "labels", "attention_mask", "position_ids", "data_index", "type_ids"]:
                    left_sample[k] = sample_to_split[k][:left_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_idx:]
                elif k in ["pixel_values", "image_flags"]:
                    left_sample[k] = sample_to_split[k][:left_img_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_img_idx:]
                else:
                    raise NotImplementedError(f"found unsupported key: {k} from {sample_to_split.keys()}")
            return left_sample, right_sample

        splitted_buffer = []
        while buffer["input_ids"].shape[0] > max_tokens:
            img_start_idx_list = (buffer["input_ids"] == img_start_token_id).nonzero().squeeze(axis=1).tolist()
            img_end_idx_list = (buffer["input_ids"] == img_end_token_id).nonzero().squeeze(axis=1).tolist()
            assert len(img_start_idx_list) == len(img_end_idx_list)
            if _image_is_splitted(buffer["input_ids"], max_tokens):
                cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
                if buffer["input_ids"][max_tokens] == img_start_token_id:
                    assert max_tokens == img_start_idx_list[cut_idx]
                    cut_left_idx = img_start_idx_list[cut_idx]
                    cut_left_img_idx = cut_idx
                else:
                    cut_left_idx = img_start_idx_list[cut_idx - 1]
                    cut_left_img_idx = cut_idx - 1
                cut_right_idx = cut_left_idx
                cut_right_img_idx = cut_left_img_idx
            else:
                cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
                if cut_img_idx < len(img_start_idx_list):
                    cut_right_idx = img_start_idx_list[cut_img_idx]
                    cut_right_img_idx = cut_img_idx
                else:
                    cut_right_idx = None
                    cut_right_img_idx = None
                cut_left_idx = max_tokens
                cut_left_img_idx = (
                    cut_right_img_idx if cut_right_img_idx is not None else buffer["pixel_values"].shape[0]
                )
            left, right = _split(
                sample_to_split=buffer,
                left_idx=cut_left_idx,
                left_img_idx=cut_left_img_idx,
                right_idx=cut_right_idx,
                right_img_idx=cut_right_img_idx,
            )
            assert (
                (left["input_ids"] == img_end_token_id).sum()
                == (left["input_ids"] == img_start_token_id).sum()
                == left["pixel_values"].shape[0]
            )
            if right is not None:
                assert (
                    (right["input_ids"] == img_end_token_id).sum()
                    == (right["input_ids"] == img_start_token_id).sum()
                    == right["pixel_values"].shape[0]
                )
            if left["pixel_values"].shape[0] >= 1 and PackedDataset.check_valid(left):
                splitted_buffer.append(left)
            if right is None or right["pixel_values"].shape[0] == 0:
                break
            buffer = right
            if buffer["input_ids"].shape[0] <= max_tokens and PackedDataset.check_valid(buffer):
                splitted_buffer.append(buffer)
                break
        logger.debug(f"split a sample into {len(splitted_buffer)} samples, current max_tokens={max_tokens}")
        return splitted_buffer

    def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
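        """Split ``buffer`` at the token budget, move chunks that hold exactly
        ``max_packed_tokens`` tokens into ``buffer_max_len_list``, and insert
        the remaining chunks into ``buffer_list``, which is kept sorted by
        image count in descending order.
        """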
        splitted_buffer = PackedDataset.split_buffer(
            buffer=buffer,
            max_tokens=self.max_packed_tokens,
            img_start_token_id=self.img_start_token_id,
            img_token_id=self.img_token_id,
            img_end_token_id=self.img_end_token_id,
        )
        for each_buffer in splitted_buffer:
            if each_buffer["pixel_values"].shape[0] > self.num_images_expected:
                logger.error(
                    f"Found a sample with {each_buffer['pixel_values'].shape[0]} images, which exceeds num_images_expected={self.num_images_expected}"
                )
                continue
            if each_buffer["input_ids"].shape[0] >= self.max_packed_tokens:
                assert each_buffer["input_ids"].shape[0] == self.max_packed_tokens
                buffer_max_len_list.append(each_buffer)
                continue
            find_idx = len(buffer_list)
            num_images_new_sample = each_buffer["pixel_values"].shape[0]
            for buffer_idx in range(len(buffer_list)):
                if buffer_list[buffer_idx]["pixel_values"].shape[0] < num_images_new_sample:
                    find_idx = buffer_idx
                    break
            buffer_list.insert(find_idx, each_buffer)
        for i in range(1, len(buffer_list)):
            assert buffer_list[i - 1]["pixel_values"].shape[0] >= buffer_list[i]["pixel_values"].shape[0]
        return buffer_list, buffer_max_len_list

    def pad_buffer(self, buffer):
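        """Pad ``buffer`` with all-zero dummy images (marked 0 in
        ``image_flags``) so that it holds exactly ``num_images_expected``
        images.
        """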
        if buffer["pixel_values"].shape[0] == self.num_images_expected:
            return buffer
        num_pad_images = self.num_images_expected - buffer["pixel_values"].shape[0]
        pad_images = paddle.stack(x=[paddle.zeros_like(x=buffer["pixel_values"][0]) for _ in range(num_pad_images)])
        pad_image_flags = paddle.to_tensor(data=[0] * num_pad_images, dtype="int64")
        buffer["pixel_values"] = paddle.concat(x=[buffer["pixel_values"], pad_images])
        buffer["image_flags"] = paddle.concat(x=[buffer["image_flags"], pad_image_flags])
        return buffer

    def postprocess_buffer(self, buffer, custom_infos=None):
        buffer["worker_state_key"] = self.worker_state_key
        buffer["worker_state_dict"] = self._state_dict
        if custom_infos is not None:
            buffer["custom_infos"] = {self.worker_state_key: copy.deepcopy(custom_infos)}
        return buffer

    def print_log(self, iter_idx, buffer_list):
        if iter_idx % self.log_freq != 0:
            return
        if self._should_log():
            logger.info(
                f"iter_idx={iter_idx!r}, len(buffer_list)={len(buffer_list)!r}, {self._state_dict['sample_info']}"
            )

    def __iter__(self):
        iter_idx = 0
        buffer_list = []
        buffer_max_len_list = []
        if self._should_log():
            logger.info(f"Begin to iter, len(buffer_list)={len(buffer_list)!r}")
        worker_id = 0 if paddle.io.get_worker_info() is None else paddle.io.get_worker_info().id
        num_workers = 1 if paddle.io.get_worker_info() is None else paddle.io.get_worker_info().num_workers
        worker_id = num_workers * self.data_rank + worker_id
        num_workers = num_workers * self.data_world_size
        rng = np.random.default_rng(seed=worker_id)
        self.worker_id = worker_id
        self.worker_state_key = f"work_state_{self.worker_id}"
        self.datasets = [d for d in self.datasets_orig]
        self.dataset_weight = [w for w in self.dataset_weight_orig]
        self.dataset_iter_list = [iter(d) for d in self.datasets]
        for ds in self.datasets:
            ds.worker_id = worker_id
            ds.worker_state_key = f"work_state_{self.worker_id}"
            ds.num_workers = num_workers
            if self._should_log() and worker_id == 0:
                logger.info(f"set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}")
        if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
            custom_infos = self.worker_custom_infos[self.worker_state_key]
            if "buffer_list" in custom_infos and isinstance(custom_infos["buffer_list"], list):
                buffer_list = custom_infos["buffer_list"]
                if self._should_log() and worker_id == 0:
                    logger.info(
                        f"[{self.worker_state_key}] load buffer list --> len(buffer_list)={len(buffer_list)!r}"
                    )
            self.worker_custom_infos = None
        logger.debug(f"{self.__class__.__name__} Rank {self.data_rank} Worker {worker_id} begin to load data")
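        # Main packing loop: draw a dataset by weight, fetch one sample, merge
        # it into a compatible buffer, then yield every buffer that is already
        # full before drawing the next sample.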
        while True:
            self.dataset_weight = [(w / sum(self.dataset_weight)) for w in self.dataset_weight]
            current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)
            try:
                current_sample = self.next_data(current_dataset_idx)
            except Exception:
                logger.info(
                    f"All datasets are exhausted, begin to empty the buffer_list (len(buffer_list)={len(buffer_list)!r})"
                )
                while len(buffer_list) > 0:
                    if self.strict_mode:
                        yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
                    else:
                        yield self.postprocess_buffer(buffer_list.pop(0))
                logger.info(f"buffer_list is empty! (len(buffer_list)={len(buffer_list)!r})")
                return
            buffer = self.find_buffer(buffer_list, current_sample)
            buffer = self.update_buffer(buffer, current_sample)
            buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)
            while len(buffer_max_len_list) > 0:
                # buffers in buffer_max_len_list hold exactly max_packed_tokens
                # tokens; log when their image count deviates from the target
                if buffer_max_len_list[0]["pixel_values"].shape[0] != self.num_images_expected:
                    logger.debug(
                        f"num tokens of a buffer reached self.max_packed_tokens={self.max_packed_tokens!r}, yield a sample with {buffer_max_len_list[0]['pixel_values'].shape[0]} images"
                    )
                if self.strict_mode and buffer_max_len_list[0]["pixel_values"].shape[0] != self.num_images_expected:
                    yield self.postprocess_buffer(
                        self.pad_buffer(buffer_max_len_list.pop(0)), {"buffer_list": buffer_list}
                    )
                else:
                    yield self.postprocess_buffer(buffer_max_len_list.pop(0), {"buffer_list": buffer_list})
            while len(buffer_list) > 0 and buffer_list[0]["pixel_values"].shape[0] > self.num_images_expected:
                logger.error(
                    f"num images of a buffer ({buffer_list[0]['pixel_values'].shape[0]}) is larger than num_images_expected ({self.num_images_expected})"
                )
                buffer_list.pop(0)
            while len(buffer_list) > 0 and buffer_list[0]["pixel_values"].shape[0] == self.num_images_expected:
                if self.debug_mode:
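                    # debug_mode: yield the same packed sample indefinitely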
                    debug_data = self.postprocess_buffer(buffer_list.pop(0), {"buffer_list": buffer_list})
                    while True:
                        yield debug_data.copy()
                yield self.postprocess_buffer(buffer_list.pop(0), {"buffer_list": buffer_list})
            while len(buffer_list) > self.max_buffer_size:
                logger.debug(
                    f"Failed to pack data to exactly {self.num_images_expected} images, yield a data sample with {buffer_list[0]['pixel_values'].shape[0]} images."
                )
                if self.strict_mode:
                    yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {"buffer_list": buffer_list})
                else:
                    yield self.postprocess_buffer(buffer_list.pop(0), {"buffer_list": buffer_list})
            self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
            iter_idx += 1

    @staticmethod
    def get_cu_seqlens_and_indexes(
        data_index: paddle.Tensor, input_ids: paddle.Tensor, labels: paddle.Tensor, len2weight: Callable
    ):
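        """Recover per-sub-sample boundaries from one packed sequence.

        For each packed sub-sample i, ``cu_seqlens`` accumulates the token
        offsets, ``indexes`` restarts position counting at 0, and every token
        of the sub-sample gets loss weight ``len2weight(num_effective_tokens)``.
        For example, data_index [0, 0, 0, 1, 1] yields cu_seqlens [0, 3, 5]
        and indexes [0, 1, 2, 0, 1].
        """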
        indexes = []
        cu_seqlens = [0]
        loss_weight = []
        start = data_index.min().item()
        end = data_index.max().item() + 1
        for i in range(start, end):
            num_tokens = (data_index == i).sum().item()
            indexes.extend(list(range(num_tokens)))
            cu_seqlens.append(cu_seqlens[-1] + num_tokens)
            assert num_tokens > 0
            curr_data_index = data_index[cu_seqlens[-2] : cu_seqlens[-2] + num_tokens]
            assert (curr_data_index == i).astype("bool").all(), data_index
            curr_labels = labels[cu_seqlens[-2] : cu_seqlens[-2] + num_tokens]
            num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
            loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)
        assert (
            len(indexes) == data_index.shape[0]
        ), f"len(indexes)={len(indexes)!r}, data_index.size(0)={data_index.shape[0]!r}"
        loss_weight = paddle.to_tensor(data=loss_weight, dtype="float32")
        return cu_seqlens, indexes, loss_weight


WARNING_CNT = defaultdict(int)


def packed_collate_fn(
    features,
    data_collator,
    len2weight: Callable,
    max_item_length: int,
    micro_num: int = 1,
    loss_reduction_all_gather: bool = False,
    pad_id: int = 0,
):
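    """Collate a list of packed features into a single batch.

    The feature list is padded to ``micro_num`` entries (padding copies carry
    fully ignored labels), per-feature ``cu_seqlens`` / position indexes /
    loss weights are derived from ``data_index``, and the stacked cu_seqlens
    are stored in ``batch["attention_mask"]`` for variable-length attention.
    ``batch["statistics"]`` records [num_samples, num_padding_tokens,
    num_padding_images].
    """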
    if not isinstance(features, list):
        features = [features]
    if len(features) > micro_num:
        raise NotImplementedError(f"len(features)={len(features)!r} > micro_num={micro_num!r}")
    if len(features) < micro_num and WARNING_CNT["micro_num_warning"] < 5:
        logger.warning(
            f"len(features)={len(features)!r} < micro_num={micro_num!r}, the features will be padded to satisfy the micro_num requirement"
        )
        WARNING_CNT["micro_num_warning"] += 1
    num_features = len(features)
    while len(features) < micro_num:
        features.append(copy.deepcopy(features[0]))
        features[-1]["labels"] = paddle.full_like(x=features[-1]["labels"], fill_value=IGNORE_TOKEN_ID)
    indexes = []
    cu_seqlens = []
    cu_num_images_list = [0]
    worker_state_key_list = []
    worker_state_dict_list = []
    worker_state_custom_infos_list = []
    batch_lens = [tuple(feat["input_ids"].shape) for feat in features]
    max_item_length = max_item_length or max(batch_lens)[0]
    num_samples = 0
    num_padding_tokens = 0
    for feat_idx, feat in enumerate(features):
        data_index = feat.pop("data_index")
        curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
            data_index=data_index, input_ids=feat["input_ids"], labels=feat["labels"], len2weight=len2weight
        )
        feat["loss_weight"] = curr_loss_weight
        if feat_idx < num_features:
            num_samples += len(curr_cu_seqlens) - 1
        if curr_cu_seqlens[-1] < max_item_length:
            curr_cu_seqlens.append(max_item_length)
            curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))
        indexes.append(paddle.to_tensor(data=curr_indexes, dtype="int64"))
        cu_seqlens.append(paddle.to_tensor(data=curr_cu_seqlens, dtype="int32"))
        worker_state_key_list.append(feat.pop("worker_state_key"))
        worker_state_dict_list.append(feat.pop("worker_state_dict"))
        worker_state_custom_infos_list.append(feat.pop("custom_infos", None))
        num_padding_tokens += max_item_length - feat["input_ids"].shape[0]
        cu_num_images_list.append(cu_num_images_list[-1] + feat["pixel_values"].shape[0])
    batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
    batch["loss_weight"] = paddle.where(
        condition=batch["labels"] == IGNORE_TOKEN_ID, x=0, y=batch["loss_weight"]
    ).tolist()
    batch["attention_mask"] = paddle.stack(x=cu_seqlens)
    batch["loss_reduction_all_gather"] = loss_reduction_all_gather
    batch["statistics"] = paddle.to_tensor(
        data=[num_samples, num_padding_tokens, batch["image_flags"].size - batch["image_flags"].sum().item()],
        dtype="int64",
    )
    batch.pop("type_ids")
    return batch
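

# A minimal wiring sketch (hypothetical: `tokenizer`, `train_datasets`, and
# `my_data_collator` are placeholders for objects built elsewhere, and the
# `len2weight` choice below is only an example of uniform weighting; this is
# not part of the module's API):
#
#     from functools import partial
#
#     packed_ds = PackedDataset(
#         tokenizer=tokenizer,
#         data_rank=get_rank(),
#         data_world_size=get_world_size(),
#         datasets=train_datasets,
#         num_images_expected=6,
#         max_packed_tokens=32768,
#     )
#     collate_fn = partial(
#         packed_collate_fn,
#         data_collator=my_data_collator,
#         len2weight=lambda n: 1.0,
#         max_item_length=32768,
#         micro_num=1,
#     )
#     loader = paddle.io.DataLoader(packed_ds, batch_size=1, collate_fn=collate_fn)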
